package logging

import (
	"database/sql/driver"
	"fmt"
	"github.com/gin-gonic/gin"
	rotatelogs "github.com/lestrrat/go-file-rotatelogs"
	"github.com/rifflock/lfshook"
	"github.com/sirupsen/logrus"
	"os"
	"pkg/setting"
	"reflect"
	"regexp"
	"strconv"
	"time"
	"unicode"
)

var (
	// log is the package-wide logrus logger. It is nil until Setup assigns it.
	log *logrus.Logger

	// sqlRegexp matches the "?" bind-parameter placeholder.
	sqlRegexp = regexp.MustCompile(`\?`)
	// numericPlaceHolderRegexp matches numbered placeholders such as $1 or $12.
	numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
)

// isPrintable reports whether s is well-formed UTF-8 made up solely of runes
// that unicode.IsPrint accepts. LogFormatter uses it to decide whether a
// []byte argument is rendered as text or as '<binary>'.
func isPrintable(s string) bool {
	const encodedReplacement = "\uFFFD"
	for i, r := range s {
		// Ranging over a string yields unicode.ReplacementChar for every byte
		// that is not valid UTF-8, and unicode.IsPrint accepts that rune, so
		// without this check arbitrary binary data counted as printable. A
		// correctly encoded U+FFFD in the input is still allowed.
		if r == unicode.ReplacementChar {
			if rest := s[i:]; len(rest) < len(encodedReplacement) || rest[:len(encodedReplacement)] != encodedReplacement {
				return false
			}
		}
		if !unicode.IsPrint(r) {
			return false
		}
	}
	return true
}

// Logger is a stateless adapter whose Print method receives gorm log records
// and forwards them to the package logrus logger.
type Logger struct{}

// LogFormatter turns the values of a gorm log record (as received by
// (*Logger).Print) into the operands that Print joins with fmt.Sprint.
//
// values[0] is the record level and values[1] its source, rendered as
// "[(source)]". When the level is "sql" the remaining values are read as:
//
//	values[2] time.Duration  query duration, rendered in milliseconds
//	values[3] string         SQL text using "?" or "$n" placeholders
//	values[4] []interface{}  bound arguments, inlined into the SQL text
//	values[5] int64          rows affected or returned
//
// and any other type in those positions panics on the type assertion. For
// every other level, values[2:] are emitted unchanged between "[" and "]".
// With fewer than two values the result is nil.
var LogFormatter = func(values ...interface{}) (messages []interface{}) {
	if len(values) <= 1 {
		return nil
	}

	level := values[0]
	messages = []interface{}{fmt.Sprintf("[(%v)]", values[1])}

	if level != "sql" {
		messages = append(messages, "[")
		messages = append(messages, values[2:]...)
		return append(messages, "]")
	}

	// Duration: integer division to hundredths of a millisecond first, so the
	// value is truncated (not rounded) before %.2f.
	messages = append(messages, fmt.Sprintf(" [%.2fms] [", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))

	// Render each bound argument as an SQL-ish literal.
	args := values[4].([]interface{})
	formattedValues := make([]string, 0, len(args))
	for _, arg := range args {
		indirect := reflect.Indirect(reflect.ValueOf(arg))
		if !indirect.IsValid() {
			// nil interface or nil pointer
			formattedValues = append(formattedValues, "NULL")
			continue
		}
		switch v := indirect.Interface().(type) {
		case time.Time:
			formattedValues = append(formattedValues, fmt.Sprintf("'%v'", v.Format("2006-01-02 15:04:05")))
		case []byte:
			if str := string(v); isPrintable(str) {
				formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
			} else {
				formattedValues = append(formattedValues, "'<binary>'")
			}
		case driver.Valuer:
			if dv, err := v.Value(); err == nil && dv != nil {
				formattedValues = append(formattedValues, fmt.Sprintf("'%v'", dv))
			} else {
				formattedValues = append(formattedValues, "NULL")
			}
		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
			formattedValues = append(formattedValues, fmt.Sprintf("%v", v))
		default:
			formattedValues = append(formattedValues, fmt.Sprintf("'%v'", v))
		}
	}

	rawSQL := values[3].(string)
	var sql string
	if numericPlaceHolderRegexp.MatchString(rawSQL) {
		// Substitute $n in a single pass. ReplaceAllStringFunc inserts the
		// returned text literally, so a '$' inside a bound value is not
		// expanded as a capture-group reference (as it was with the former
		// ReplaceAllString(sql, value+"$1")). A value that itself contains
		// "$2" is also not substituted again by a later placeholder.
		sql = numericPlaceHolderRegexp.ReplaceAllStringFunc(rawSQL, func(placeholder string) string {
			digits := placeholder[1:]
			n, err := strconv.Atoi(digits)
			if err != nil || n < 1 || n > len(formattedValues) || digits[0] == '0' {
				// No bound value for this placeholder: leave it untouched.
				return placeholder
			}
			return formattedValues[n-1]
		})
	} else {
		// "?" placeholders: interleave the SQL fragments with the values in
		// order. NOTE(review): surplus "?" (more than values) are dropped, and
		// surplus values are appended after the last fragment; confirm this is
		// acceptable for log output.
		for i, fragment := range sqlRegexp.Split(rawSQL, -1) {
			sql += fragment
			if i < len(formattedValues) {
				sql += formattedValues[i]
			}
		}
	}

	messages = append(messages, sql+"]")
	messages = append(messages, fmt.Sprintf(" [%v rows affected or returned]", strconv.FormatInt(values[5].(int64), 10)))
	return messages
}

// Print receives gorm log records, formats them with LogFormatter and
// writes the joined result to the package logger at debug level.
func (logger *Logger) Print(values ...interface{}) {
	parts := LogFormatter(values...)
	log.Debugln(fmt.Sprint(parts...))
}

// GinLogger returns a gin middleware that lets the rest of the handler chain
// run and then logs, at debug level, the response status, the time spent,
// the client address, the HTTP method and the request path.
func GinLogger() gin.HandlerFunc {
	return func(c *gin.Context) {
		began := time.Now()
		// Run the remaining handlers before measuring and logging.
		c.Next()
		elapsed := time.Since(began)

		// NOTE(review): the port appended here is the server's configured
		// HttpPort, not the client's source port — confirm this is intended.
		addr := c.ClientIP() + ":" + strconv.Itoa(setting.ServerSetting.HttpPort)

		log.Debugf("| %3d | %13v | %15s | %s  %s |",
			c.Writer.Status(),
			elapsed,
			addr,
			c.Request.Method,
			c.Request.URL.Path,
		)
	}
}

// Debug logs args at debug level through the package logger, combining them
// the way logrus's Logger.Debug does. The package logger is only assigned by
// Setup, so Setup must run first.
func Debug(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Debug(args...)
}

// Debugln logs args at debug level through the package logger, separated by
// spaces as in logrus's Logger.Debugln. The package logger is only assigned
// by Setup, so Setup must run first.
func Debugln(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Debugln(args...)
}

// Debugf formats args according to format (fmt.Sprintf verbs) and logs the
// result at debug level through the package logger. The package logger is
// only assigned by Setup, so Setup must run first.
func Debugf(format string, args ...interface{}) {
	// Expand args so each one fills its own verb; passing the slice made it a
	// single operand and left the remaining verbs reported as MISSING.
	log.Debugf(format, args...)
}

// Info logs args at info level through the package logger, combining them
// the way logrus's Logger.Info does. The package logger is only assigned by
// Setup, so Setup must run first.
func Info(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Info(args...)
}

// Infoln logs args at info level through the package logger, separated by
// spaces as in logrus's Logger.Infoln. The package logger is only assigned
// by Setup, so Setup must run first.
func Infoln(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Infoln(args...)
}

// Infof formats args according to format (fmt.Sprintf verbs) and logs the
// result at info level through the package logger. The package logger is
// only assigned by Setup, so Setup must run first.
func Infof(format string, args ...interface{}) {
	// Expand args so each one fills its own verb; passing the slice made it a
	// single operand and left the remaining verbs reported as MISSING.
	log.Infof(format, args...)
}

// Warn logs args at warning level through the package logger, combining them
// the way logrus's Logger.Warn does. The package logger is only assigned by
// Setup, so Setup must run first.
func Warn(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Warn(args...)
}

// Warnln logs args at warning level through the package logger, separated by
// spaces as in logrus's Logger.Warnln. The package logger is only assigned
// by Setup, so Setup must run first.
func Warnln(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Warnln(args...)
}

// Warnf formats args according to format (fmt.Sprintf verbs) and logs the
// result at warning level through the package logger. The package logger is
// only assigned by Setup, so Setup must run first.
func Warnf(format string, args ...interface{}) {
	// Expand args so each one fills its own verb; passing the slice made it a
	// single operand and left the remaining verbs reported as MISSING.
	log.Warnf(format, args...)
}

// Error logs args at error level through the package logger, combining them
// the way logrus's Logger.Error does. The package logger is only assigned by
// Setup, so Setup must run first.
func Error(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Error(args...)
}

// Errorln logs args at error level through the package logger, separated by
// spaces as in logrus's Logger.Errorln. The package logger is only assigned
// by Setup, so Setup must run first.
func Errorln(args ...interface{}) {
	// Expand args: passing the slice itself logs it as one bracketed value.
	log.Errorln(args...)
}

// Errorf formats args according to format (fmt.Sprintf verbs) and logs the
// result at error level through the package logger. The package logger is
// only assigned by Setup, so Setup must run first.
func Errorf(format string, args ...interface{}) {
	// Expand args so each one fills its own verb; passing the slice made it a
	// single operand and left the remaining verbs reported as MISSING.
	log.Errorf(format, args...)
}

// Setup creates the package logger.
//
// Console output goes to stdout. The level is Debug when
// setting.ServerSetting.RunMode is "debug" and Info otherwise. Setup also
// attaches an lfshook hook that writes every level as JSON (timestamps in
// setting.AppSetting.DateFormat) to a rotating file named
// RuntimeRootPath+LogPath+"mwx-go.log".<date>.log. If the rotating writer
// cannot be created, the error is logged and only console output is
// configured.
func Setup() {
	log = logrus.New()
	// Console output; the file output is attached below as a hook.
	log.SetOutput(os.Stdout)
	if setting.ServerSetting.RunMode == "debug" {
		log.SetLevel(logrus.DebugLevel)
	} else {
		log.SetLevel(logrus.InfoLevel)
	}

	apiLogPath := setting.AppSetting.RuntimeRootPath + setting.AppSetting.LogPath + "mwx-go.log"
	logWriter, err := rotatelogs.New(
		// One file per day; "%Y-%m-%d %H-%M-%S" would give finer-grained names.
		apiLogPath+".%Y-%m-%d.log",
		// Symlink that points at the newest log file.
		rotatelogs.WithLinkName(apiLogPath),
		// NOTE(review): a negative max age appears intended to keep files
		// forever (7*24*time.Hour would keep one week); confirm against the
		// rotatelogs version in use.
		rotatelogs.WithMaxAge(-1),
		// Interval between file rotations.
		rotatelogs.WithRotationTime(24*time.Hour),
	)
	if err != nil {
		// The error used to be discarded and the unusable writer registered
		// with the hook. Keep console logging and report the problem instead.
		log.Errorf("logging: cannot create rotating log file %q: %v", apiLogPath, err)
		return
	}

	// Route every logrus level, including Panic, to the rotating file.
	writeMap := lfshook.WriterMap{
		logrus.PanicLevel: logWriter,
		logrus.FatalLevel: logWriter,
		logrus.ErrorLevel: logWriter,
		logrus.WarnLevel:  logWriter,
		logrus.InfoLevel:  logWriter,
		logrus.DebugLevel: logWriter,
		logrus.TraceLevel: logWriter,
	}
	lfHook := lfshook.NewHook(writeMap, &logrus.JSONFormatter{
		TimestampFormat: setting.AppSetting.DateFormat,
	})
	log.AddHook(lfHook)
}
